{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "\n",
    "import time\n",
    "import tqdm\n",
    "import itertools\n",
    "from allennlp.nn import util"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from chinese_gpt import TransformerEncoder as Encoder\n",
    "from chinese_gpt import TransformerDecoderLM as Decoder\n",
    "from pytorch_pretrained_bert import BertTokenizer, OpenAIAdam"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# You need this library for beam search\n",
    "from allennlp.nn.beam_search import BeamSearch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def top_k_logits(logits, k):\n",
    "    values, _ = torch.topk(logits, k)\n",
    "    min_values = values[:, -1].unsqueeze(1).repeat(1, logits.shape[-1])\n",
    "    return torch.where(logits < min_values, torch.ones_like(logits, dtype=logits.dtype) * -1e10, logits)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You need to define a step function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "temperature = 1.0\n",
    "\n",
    "def take_step(last_predictions, state):\n",
    "    \n",
    "    past = state['past'].permute(1, 2, 0, 3, 4, 5)\n",
    "    past = [past[i] for i in range(12)]\n",
    "    past_length = state[\"length\"][0].item()\n",
    "    \n",
    "    logits, past = decoder(last_predictions.view(-1, 1), \n",
    "                                  state[\"mask\"], \n",
    "                                  past=past, \n",
    "                                  past_length=past_length)\n",
    "\n",
    "    logits = logits.squeeze(1) / temperature\n",
    "    log_probs = F.log_softmax(logits, dim=-1)\n",
    "    \n",
    "    state[\"mask\"] = F.pad(state[\"mask\"], (0, 1), \"constant\", 1.0)\n",
    "    state[\"past\"] = torch.stack(past, 0).permute(2, 0, 1, 3, 4, 5)\n",
    "    state[\"length\"] += 1\n",
    "    \n",
    "    return log_probs, state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = BertTokenizer.from_pretrained(\"bert-base-chinese\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "beam_search = BeamSearch(end_index=102, max_steps=40, beam_size=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 140,
   "metadata": {},
   "outputs": [],
   "source": [
    "encoder = Encoder()\n",
    "decoder = Decoder()\n",
    "encoder.eval()\n",
    "decoder.eval()\n",
    "\n",
    "# pretrained weights\n",
    "encoder.load_state_dict(torch.load(\"5encoder.pth\"))\n",
    "decoder.load_state_dict(torch.load(\"5decoder.pth\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 141,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load validation dataset\n",
    "val_data = torch.load(\"val_data.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 142,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 1\n",
    "val_dataset = TensorDataset(*val_data)\n",
    "val_dataloader = DataLoader(dataset=val_dataset, shuffle=False, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 143,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda\")\n",
    "\n",
    "encoder = encoder.to(device)\n",
    "decoder = decoder.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Beam Search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 307,
   "metadata": {},
   "outputs": [],
   "source": [
    "loader = iter(val_dataloader)\n",
    "\n",
    "for i in range(530):\n",
    "    next(loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 308,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1279c151df184404bc38e7670933fa88",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=5161), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generated beam-1:\n",
      "终于找到火箭惨败的真因了！别光怪主力不行，这才是最大祸根[SEP][SEP][SEP][SEP]\n",
      "Generated beam-2:\n",
      "终于找到火箭惨败的真因了！别光怪主力不行，这1点才是最大祸首[SEP][SEP]\n",
      "原标题+summarization:\n",
      "[CLS]引爆火箭交易的不是争冠，而是主力要被累死了[SEP]休斯顿火箭队感恩节之前表现不错，五连胜终于打出了上赛季的火热状态，似乎新赛季一切都随着时间的推移变得好了起来。不仅被活塞队复仇，还输给了东部鱼腩骑士队，面对着矛盾丛生的奇才队也以失利告终。我们经常在比赛的开局阶段表现得很差，经常会打得非常软，从一开始就注定要输球\"。[SEP]\n",
      "运营:\n",
      "[CLS]势在必行！火箭交易只因1.6亿花得不值，这2人成全队最大障碍[SEP]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "all_outputs = []\n",
    "\n",
    "with torch.no_grad():\n",
    "    for batch in tqdm.tqdm_notebook(loader):\n",
    "        batch = [item.to(device) for item in batch]\n",
    "\n",
    "        encoder_input, \\\n",
    "                third, \\\n",
    "                mask_encoder_input, \\\n",
    "                mask_third, \\\n",
    "                encoder_type_ids, \\\n",
    "                third_type_ids = batch\n",
    "\n",
    "        _, past = encoder(encoder_input, mask_encoder_input, encoder_type_ids)\n",
    "\n",
    "        state = {}\n",
    "        start_predictions = torch.LongTensor([[101]]* batch_size).to(device)\n",
    "        mask = torch.ones(batch_size, start_predictions.shape[1]).to(device)\n",
    "        mask = torch.cat([mask_encoder_input.float(), mask], dim=1)\n",
    "        state[\"mask\"] = mask\n",
    "        state[\"length\"] = torch.LongTensor([[0]] * batch_size).to(device)\n",
    "        state[\"past\"] = torch.stack(past, 0).permute(2, 0, 1, 3, 4, 5)\n",
    "\n",
    "        all_top_k_predictions, log_probabilities = beam_search.search(start_predictions, state, take_step)\n",
    "\n",
    "\n",
    "   \n",
    "        all_outputs.append(\"\".join(tokenizer.convert_ids_to_tokens(all_top_k_predictions[0][0].tolist())\n",
    "                      ).replace(\"##\", \"\"))\n",
    "        print(\"Generated beam-1:\")\n",
    "        print(\"\".join(tokenizer.convert_ids_to_tokens(all_top_k_predictions[0][0].tolist())\n",
    "             ).replace(\"##\", \"\"))\n",
    "        print(\"Generated beam-2:\")\n",
    "        print(\"\".join(tokenizer.convert_ids_to_tokens(all_top_k_predictions[0][-1].tolist())\n",
    "                     ).replace(\"##\", \"\"))\n",
    "        print(\"原标题+summarization:\")\n",
    "        print(\"\".join(tokenizer.convert_ids_to_tokens(encoder_input[0].tolist())\n",
    "                     ).replace(\"##\", \"\").replace(\"[PAD]\", \"\"))\n",
    "        print(\"运营:\")\n",
    "        print(\"\".join(tokenizer.convert_ids_to_tokens(third[0].tolist())\n",
    "                     ).replace(\"##\", \"\").replace(\"[PAD]\", \"\"))\n",
    "        \n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_outputs = [item.replace(\"[SEP]\", \"\") for item in all_outputs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"results\", \"w\") as f:\n",
    "    for line in all_outputs:\n",
    "        f.write(line)\n",
    "        f.write(\"\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sampling Based Search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 309,
   "metadata": {},
   "outputs": [],
   "source": [
    "loader = iter(val_dataloader)\n",
    "\n",
    "for i in range(530):\n",
    "    batch = next(loader)\n",
    "\n",
    "batch = next(loader)\n",
    "batch = [item.to(device) for item in batch]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 326,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generated sampled\n",
      "1.39405121091997e-11\n",
      "终于找到火箭失利的真因了！主力要被累死了，这才是最根本祸根[SEP]终结[SEP]终结[SEP]终场哨响[SEP]\n",
      "原标题+summarization:\n",
      "[CLS]引爆火箭交易的不是争冠，而是主力要被累死了[SEP]休斯顿火箭队感恩节之前表现不错，五连胜终于打出了上赛季的火热状态，似乎新赛季一切都随着时间的推移变得好了起来。不仅被活塞队复仇，还输给了东部鱼腩骑士队，面对着矛盾丛生的奇才队也以失利告终。我们经常在比赛的开局阶段表现得很差，经常会打得非常软，从一开始就注定要输球\"。[SEP]\n",
      "运营:\n",
      "[CLS]势在必行！火箭交易只因1.6亿花得不值，这2人成全队最大障碍[SEP]\n"
     ]
    }
   ],
   "source": [
    "length = 0\n",
    "top_k = 10\n",
    "\n",
    "total_prob = 0.0\n",
    "\n",
    "encoder_input, \\\n",
    "        third, \\\n",
    "        mask_encoder_input, \\\n",
    "        mask_third, \\\n",
    "        encoder_type_ids, \\\n",
    "        third_type_ids = batch\n",
    "\n",
    "_, past = encoder(encoder_input, mask_encoder_input, encoder_type_ids)\n",
    "\n",
    "start_predictions = torch.LongTensor([[101]]* batch_size).to(device)\n",
    "mask = torch.ones(batch_size, start_predictions.shape[1]).to(device)\n",
    "mask = torch.cat([mask_encoder_input.float(), mask], dim=1)\n",
    "\n",
    "logits, past = decoder(start_predictions, mask, past=past, past_length=0)\n",
    "logits = logits.squeeze(1) / 1.0\n",
    "logits = top_k_logits(logits, k=top_k)\n",
    "\n",
    "sentence = []\n",
    "\n",
    "probs = F.softmax(logits, dim=-1)\n",
    "prob, prev_pred = torch.topk(probs, k=1, dim=-1)\n",
    "sentence.append(prev_pred)\n",
    "length += 1\n",
    "total_prob += np.log(prob.item())\n",
    "\n",
    "for i in range(40):\n",
    "    mask = F.pad(mask, (0, 1), \"constant\", 1.0)\n",
    "    logits, past = decoder(prev_pred, mask, past=past, past_length=length)\n",
    "    logits = logits.squeeze(1) / 1.0\n",
    "    logits = top_k_logits(logits, k=top_k)\n",
    "    probs = F.softmax(logits, dim=-1)\n",
    "    prev_pred = torch.multinomial(probs, num_samples=1)\n",
    "    sentence.append(prev_pred)\n",
    "    length += 1\n",
    "    total_prob += np.log(probs[0, prev_pred.item()].item())\n",
    "    \n",
    "sentence = torch.cat(sentence, dim=-1)\n",
    "\n",
    "print(\"Generated sampled\")\n",
    "print(np.exp(total_prob))\n",
    "print(\"\".join(tokenizer.convert_ids_to_tokens(sentence[0].tolist())\n",
    "             ).replace(\"##\", \"\"))\n",
    "print(\"原标题+summarization:\")\n",
    "print(\"\".join(tokenizer.convert_ids_to_tokens(encoder_input[0].tolist())\n",
    "             ).replace(\"##\", \"\").replace(\"[PAD]\", \"\"))\n",
    "print(\"运营:\")\n",
    "print(\"\".join(tokenizer.convert_ids_to_tokens(third[0].tolist())\n",
    "             ).replace(\"##\", \"\").replace(\"[PAD]\", \"\"))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
